{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "synthesize_demo",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qJDJLE3v0HNr"
      },
      "source": [
        "# Fetch Codebase and Install Environment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qy1nwGJV5JuG"
      },
      "source": [
        "import os\n",
        "os.chdir('/content')\n",
        "CODE_DIR = 'GenForce'\n",
        "!git clone https://github.com/genforce/genforce.git $CODE_DIR\n",
        "os.chdir(f'./{CODE_DIR}')\n",
        "!pip install -r requirements.txt > installation_output.txt"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qh5DFyyg0Ntm"
      },
      "source": [
        "# Define Utility Functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qcSdJW5V0M-8"
      },
      "source": [
        "import os\n",
        "import subprocess\n",
        "import io\n",
        "import IPython.display\n",
        "import numpy as np\n",
        "import PIL.Image\n",
        "\n",
        "import torch\n",
        "\n",
        "from models import MODEL_ZOO\n",
        "from models import build_generator\n",
        "from utils.visualizer import fuse_images\n",
        "\n",
        "\n",
        "def postprocess(images):\n",
        "  \"\"\"Post-processes images from `torch.Tensor` to `numpy.ndarray`.\"\"\"\n",
        "  images = images.detach().cpu().numpy()\n",
        "  images = (images + 1) * 255 / 2\n",
        "  images = np.clip(images + 0.5, 0, 255).astype(np.uint8)\n",
        "  images = images.transpose(0, 2, 3, 1)\n",
        "  return images\n",
        "\n",
        "\n",
        "def build(model_name):\n",
        "  \"\"\"Builds generator and load pre-trained weights.\"\"\"\n",
        "  model_config = MODEL_ZOO[model_name].copy()\n",
        "  url = model_config.pop('url')  # URL to download model if needed.\n",
        "\n",
        "  # Build generator.\n",
        "  print(f'Building generator for model `{model_name}` ...')\n",
        "  generator = build_generator(**model_config)\n",
        "  print(f'Finish building generator.')\n",
        "\n",
        "  # Load pre-trained weights.\n",
        "  os.makedirs('checkpoints', exist_ok=True)\n",
        "  checkpoint_path = os.path.join('checkpoints', model_name + '.pth')\n",
        "  print(f'Loading checkpoint from `{checkpoint_path}` ...')\n",
        "  if not os.path.exists(checkpoint_path):\n",
        "    print(f'  Downloading checkpoint from `{url}` ...')\n",
        "    subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])\n",
        "    print(f'  Finish downloading checkpoint.')\n",
        "  checkpoint = torch.load(checkpoint_path, map_location='cpu')\n",
        "  if 'generator_smooth' in checkpoint:\n",
        "    generator.load_state_dict(checkpoint['generator_smooth'])\n",
        "  else:\n",
        "    generator.load_state_dict(checkpoint['generator'])\n",
        "  generator = generator.cuda()\n",
        "  generator.eval()\n",
        "  print(f'Finish loading checkpoint.')\n",
        "  return generator\n",
        "\n",
        "\n",
        "def synthesize(generator, num, synthesis_kwargs=None, batch_size=1, seed=0):\n",
        "  \"\"\"Synthesize images.\"\"\"\n",
        "  assert num > 0 and batch_size > 0\n",
        "  synthesis_kwargs = synthesis_kwargs or dict()\n",
        "\n",
        "  # Set random seed.\n",
        "  np.random.seed(seed)\n",
        "  torch.manual_seed(seed)\n",
        "\n",
        "  # Sample and synthesize.\n",
        "  outputs = []\n",
        "  for idx in range(0, num, batch_size):\n",
        "    batch = min(batch_size, num - idx)\n",
        "    code = torch.randn(batch, generator.z_space_dim).cuda()\n",
        "    with torch.no_grad():\n",
        "      images = generator(code, **synthesis_kwargs)['image']\n",
        "      images = postprocess(images)\n",
        "    outputs.append(images)\n",
        "  return np.concatenate(outputs, axis=0)\n",
        "\n",
        "\n",
        "def imshow(images, viz_size=256, col=0, spacing=0):\n",
        "  \"\"\"Shows images in one figure.\"\"\"\n",
        "  fused_image = fuse_images(\n",
        "    images,\n",
        "    col=col,\n",
        "    image_size=viz_size,\n",
        "    row_spacing=spacing,\n",
        "    col_spacing=spacing\n",
        "  )\n",
        "  fused_image = np.asarray(fused_image, dtype=np.uint8)\n",
        "  data = io.BytesIO()\n",
        "  PIL.Image.fromarray(fused_image).save(data, 'jpeg')\n",
        "  im_data = data.getvalue()\n",
        "  disp = IPython.display.display(IPython.display.Image(im_data))\n",
        "  return disp"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rIrseINa879H"
      },
      "source": [
        "# Select a Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RyoJmv-PtZo_"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "model_name = \"stylegan_diningroom256\" #@param ['pggan_celebahq1024', 'pggan_bedroom256', 'pggan_livingroom256', 'pggan_diningroom256', 'pggan_kitchen256', 'pggan_churchoutdoor256', 'pggan_tower256', 'pggan_bridge256', 'pggan_restaurant256', 'pggan_classroom256', 'pggan_conferenceroom256', 'pggan_person256', 'pggan_cat256', 'pggan_dog256', 'pggan_bird256', 'pggan_horse256', 'pggan_sheep256', 'pggan_cow256', 'pggan_car256', 'pggan_bicycle256', 'pggan_motorbike256', 'pggan_bus256', 'pggan_train256', 'pggan_boat256', 'pggan_airplane256', 'pggan_bottle256', 'pggan_chair256', 'pggan_pottedplant256', 'pggan_tvmonitor256', 'pggan_diningtable256', 'pggan_sofa256', 'stylegan_ffhq1024', 'stylegan_celebahq1024', 'stylegan_bedroom256', 'stylegan_cat256', 'stylegan_car512', 'stylegan_celeba_partial256', 'stylegan_ffhq256', 'stylegan_ffhq512', 'stylegan_livingroom256', 'stylegan_diningroom256', 'stylegan_kitchen256', 'stylegan_apartment256', 'stylegan_church256', 'stylegan_tower256', 'stylegan_bridge256', 'stylegan_restaurant256', 'stylegan_classroom256', 'stylegan_conferenceroom256', 'stylegan_animeface512', 'stylegan_animeportrait512', 'stylegan_artface512', 'stylegan2_ffhq1024', 'stylegan2_church256', 'stylegan2_cat256', 'stylegan2_horse256', 'stylegan2_car512']\n",
        "\n",
        "generator = build(model_name)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RsGPMc5E8_jn"
      },
      "source": [
        "# Synthesize"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jPkIXKxp4-7L"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "num_samples = 8 #@param {type:\"slider\", min:1, max:20, step:1}\n",
        "noise_seed = 488 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
        "truncation = 1 #@param {type:\"slider\", min:0.0, max:1, step:0.02}\n",
        "truncation_layers = 3 #@param {type:\"slider\", min:0, max:18, step:1}\n",
        "randomize_noise = 'false' #@param ['true', 'false']\n",
        "\n",
        "synthesis_kwargs = dict(trunc_psi=1 - truncation,\n",
        "             trunc_layers=truncation_layers,\n",
        "             randomize_noise=randomize_noise)\n",
        "images = synthesize(generator, num_samples, synthesis_kwargs, seed=noise_seed)\n",
        "imshow(images)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}